{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "44fc9cd0",
   "metadata": {},
   "source": [
    "(using-built-in-training-function)=\n",
    "# Using the built-in training function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e900797",
   "metadata": {},
   "source": [
    "The MLRun [Function Hub](https://www.mlrun.org/hub/) includes, among other things, training functions. The most commonly used function for training is [`auto_trainer`](https://github.com/mlrun/functions/tree/development/auto_trainer), which includes the following handlers:\n",
    "\n",
    "- [Train](#train)\n",
    "- [Evaluate](#evaluate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40860be8",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "The main and default handler of any training function is called `\"train\"`. In the Auto Trainer this handler performs\n",
    "an ML training function using SciKit-Learn's API, meaning the function follows the structure below:\n",
    "\n",
    "1. **Get the data**: Get the dataset passed to a local path.\n",
    "2. **Split the data into datasets**: Split the given data into a training set and a testing set.\n",
    "3. **Get the model**: Initialize a model instance out of a given class or load a provided model<br>\n",
    "   The supported classes are anything based on `sklearn.Estimator`, `xgboost.XGBModel`, `lightgbm.LGBMModel`, including custom code as well.\n",
    "4. **Train**: Call the model's `fit` method to train it on the training set.\n",
    "5. **Test**: Test the model on the testing set.\n",
    "6. **Log**: Calculate the metrics and produce the artifacts to log the results and plots.\n",
    "\n",
    "MLRun orchestrates all of the above steps. The training is done with the shortcut function `apply_mlrun` that\n",
    "enables the automatic logging and additional features."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd64cb0a",
   "metadata": {},
   "source": [
    "To start, run `import mlrun` and create a project:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33166c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mlrun\n",
    "\n",
    "# Set the base project name\n",
    "project_name_base = \"training-test\"\n",
    "\n",
    "# Initialize the MLRun project object\n",
    "project = mlrun.get_or_create_project(\n",
    "    project_name_base, context=\"./\", user_project=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13cdff93",
   "metadata": {},
   "source": [
    "Next, import the Auto Trainer from the Function Hub using MLRun's `import_function` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "14bef14d",
   "metadata": {},
   "outputs": [],
   "source": [
    "auto_trainer = project.set_function(mlrun.import_function(\"hub://auto_trainer\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd9a5434",
   "metadata": {},
   "source": [
    "The following example trains a Random Forest model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "587efb91",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_url = \"https://s3.wasabisys.com/iguazio/data/function-marketplace-data/xgb_trainer/classifier-data.csv\"\n",
    "\n",
    "train_run = auto_trainer.run(\n",
    "    handler=\"train\",\n",
    "    inputs={\"dataset\": dataset_url},\n",
    "    params={\n",
    "        # Model parameters:\n",
    "        \"model_class\": \"sklearn.ensemble.RandomForestClassifier\",\n",
    "        \"model_kwargs\": {\n",
    "            \"max_depth\": 8\n",
    "        },  # Could be also passed as \"MODEL_max_depth\": 8\n",
    "        \"model_name\": \"MyModel\",\n",
    "        # Dataset parameters:\n",
    "        \"drop_columns\": [\"feat_0\", \"feat_2\"],\n",
    "        \"train_test_split_size\": 0.2,\n",
    "        \"random_state\": 42,\n",
    "        \"label_columns\": \"labels\",\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a97e29e0",
   "metadata": {},
   "source": [
    "### Outputs\n",
    "\n",
    "`train_run.outputs` returns all the outputs. The outputs are:\n",
    "\n",
    "* **Trained model**: The trained model is logged as a `ModelArtifact` with all the following artifacts registered\n",
    "  to it.\n",
    "* **Test dataset**: The test set used to test the model post training is logged as a `DatasetArtifact`.\n",
    "* **Plots**: Informative plots regarding the model like confusion matrix and features importance are drawn and logged\n",
    "  as `PlotArtifact`s.\n",
    "* **Results**: List of all the calculations of metrics tested on the testing set.\n",
    "\n",
    "For instance, calling `train_run.artifact('confusion-matrix').show()` shows the following confusion matrix:\n",
    "\n",
    "![confusion matrix](../_static/images/confusion-matrix.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5a60d1e",
   "metadata": {},
   "source": [
    "### Parameters\n",
    "\n",
    "To view the parameters of `train`, expand the section below:\n",
    "\n",
    "````{dropdown} train handler parameters:\n",
    "\n",
    "**Model Parameters**\n",
    "\n",
    "*Parameters to initialize a new model object or load a logged one for retraining.*\n",
    "\n",
    "* `model_class`: `str` &mdash; The class of the model to initialize. Can be a module path like\n",
    "  `\"sklearn.linear_model.LogisticRegression\"` or a custom model passed through the custom objects parameters below.\n",
    "  Only one of `model_class` and `model_path` can be given.\n",
    "* `model_path`: `str` &mdash; A `ModelArtifact` URI to load and retrain. Only one of `model_class` and `model_path` can be\n",
    "  given.\n",
    "* `model_kwargs`: `dict` &mdash; Additional parameters to pass onto the initialization of the model object (the model's class\n",
    "  `__init__` method).\n",
    "\n",
    "**Data parameters**\n",
    "\n",
    "*Parameters to get a dataset and prepare it for training, splitting into training and testing if required.*\n",
    "\n",
    "* `dataset`: `Union[str, list, dict]` &mdash; The dataset to train the model on.\n",
    "  * Can be passed as part of `inputs` to be parsed as `mlrun.DataItem`, meaning it supports either a URI or a\n",
    "    FeatureVector.\n",
    "  * Can be passed as part of `params`, meaning it can be a `list` or a `dict`.\n",
    "* `drop_columns`: `Union[str, int, List[str], List[int]]` &mdash; Columns to drop from the dataset. Can be passed as strings\n",
    "  representing the column names or integers representing the column numbers.\n",
    "* `test_set`: `Union[str, list, dict]` &mdash; The test set to test the model with post training. Notice only one of\n",
    "  `test_set` or `train_test_split_size` is expected.\n",
    "  * Can be passed as part of `inputs` to be parsed as `mlrun.DataItem`, meaning it supports either a URI or a\n",
    "    FeatureVector.\n",
    "  * Can be passed as part of `params`, meaning it can be a `list` or a `dict`.\n",
    "* `train_test_split_size`: `float` = `0.2` &mdash; The proportion of the dataset to include in the test split. The size of the\n",
    "  Training set is set to the complement of this value. Must be between 0.0 and 1.0. Defaults to 0.2\n",
    "* `label_columns`: `Union[str, int, List[str], List[int]]` &mdash; The target label(s) of the column(s) in the dataset. Can\n",
    "  be passed as strings representing the column names or integers representing the column numbers.\n",
    "* `random_state`: `int` - Random state (seed) for `train_test_split`.\n",
    "\n",
    "**Train parameters**\n",
    "\n",
    "*Parameters to pass to the `fit` method of the model object.*\n",
    "\n",
    "* `train_kwargs`: `dict` &mdash; Additional parameters to pass onto the `fit` method.\n",
    "\n",
    "**Logging parameters**\n",
    "\n",
    "*Parameters to control the automatic logging feature of MLRun. You can adjust the logging outputs as relevant and if\n",
    "not passed, a default list of artifacts and metrics is produced and calculated.*\n",
    "\n",
    "* `model_name`: `str` = `\"model`\" &mdash; The model’s name to use for storing the model artifact, defaults to ‘model’.\n",
    "* `tag`: `str` &mdash; The model’s tag to log with.\n",
    "* `sample_set`: `Union[str, list, dict]` &mdash; A sample set of inputs for the model for logging its stats alongside the model in\n",
    "  favor of model monitoring. If not given, the training set is used instead.\n",
    "  * Can be passed as part of `inputs` to be parsed as `mlrun.DataItem`, meaning it supports either a URI or a\n",
    "    FeatureVector.\n",
    "  * Can be passed as part of `params`, meaning it can be a `list` or a `dict`.\n",
    "* `_artifacts`: `Dict[str, Union[list, dict]]` &mdash; Additional artifacts to produce post training. See the\n",
    "  `ArtifactsLibrary` of the desired framework to see the available list of artifacts.\n",
    "* `_metrics`: `Union[List[str], Dict[str, Union[list, dict]]]` &mdash; Additional metrics to calculate post training. See how\n",
    "  to pass metrics and custom metrics in the `MetricsLibrary` of the desired framework.\n",
    "* `apply_mlrun_kwargs`: `dict` &mdash; Framework specific `apply_mlrun` key word arguments. Refer to the framework of choice\n",
    "  to know more ([SciKit-Learn](), [XGBoost]() or [LightGBM]())\n",
    "\n",
    "**Custom objects parameters**\n",
    "\n",
    "*Parameters to include custom objects like custom model class, metric code and artifact plan. Keep in mind that the\n",
    "model artifact created is logged with the custom objects, so if `model_path` is used, the custom objects used to\n",
    "train it are not required for loading it, it happens automatically.*\n",
    "\n",
    "* `custom_objects_map`: `Union[str, Dict[str, Union[str, List[str]]]]` &mdash; A map of all the custom objects required for\n",
    "  loading, training and testing the model. Can be passed as a dictionary or a json file path. Each key is a path to a\n",
    "  python file and its value is the custom object name to import from it. If multiple objects needed to be imported from\n",
    "  the same py file a list can be given. For example:\n",
    "  ```python\n",
    "  {\n",
    "      \"/.../custom_model.py\": \"MyModel\",\n",
    "      \"/.../custom_objects.py\": [\"object1\", \"object2\"]\n",
    "  }\n",
    "  ```\n",
    "  All the paths are accessed from the given 'custom_objects_directory', meaning each py file is read from\n",
    "  'custom_objects_directory/<MAP VALUE>'. If the model path given is of a store object, the custom objects map is\n",
    "  read from the logged custom object map artifact of the model. \n",
    "  \n",
    "   ```{admonition} Note\n",
    "   The custom objects are imported in the order they came in this dictionary (or json). If a custom \n",
    "   object is dependent on another, make sure to put it below the one it relies on.\n",
    "   ```  \n",
    "  \n",
    "  \n",
    "* `custom_objects_directory`: Path to the directory with all the python files required for the custom objects. Can be\n",
    "  passed as a zip file as well (and are extracted during the start of the run).\n",
    "\n",
    " \n",
    "   ```{admonition} Note\n",
    "   The parameters for additional arguments `model_kwargs`, `train_kwargs` and `apply_mlrun_kwargs` can be\n",
    "  also passed in the global `kwargs` with the matching prefixes: `\"MODEL_\"`, `\"TRAIN_\"`, `\"MLRUN_\"`.\n",
    "   ```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f147e1e",
   "metadata": {},
   "source": [
    "(auto_trainer_evaluate)=\n",
    "\n",
    "## Evaluate\n",
    "\n",
    "The `\"evaluate\"` handler is used to test the model on a given testing set and log its results. This is a common phase in\n",
    "every model lifecycle and should be done periodically on updated testing sets to confirm that your model is still relevant.\n",
    "The function uses SciKit-Learn's API for evaluation, meaning the function follows the structure below:\n",
    "\n",
    "1. **Get the data**: Get the testing dataset passed to a local path.\n",
    "2. **Get the model**: Get the model object out of the `ModelArtifact` URI.\n",
    "3. **Predict**: Call the model's `predict` (and `predict_proba` if needed) method to test it on the testing set.\n",
    "4. **Log**: Test the model on the testing set and log the results and artifacts.\n",
    "\n",
    "MLRun orchestrates all of the above steps. The evaluation is done with the shortcut function `apply_mlrun` that\n",
    "enables the automatic logging and further features."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f461746",
   "metadata": {},
   "source": [
    "To evaluate the test-set, use the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cba141e",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate_run = auto_trainer.run(\n",
    "    handler=\"evaluate\",\n",
    "    inputs={\"dataset\": train_run.outputs[\"test_set\"]},\n",
    "    params={\n",
    "        \"model\": train_run.outputs[\"model\"],\n",
    "        \"label_columns\": \"labels\",\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "923bfedb",
   "metadata": {},
   "source": [
    "### Outputs\n",
    "\n",
    "`evaluate_run.outputs` returns all the outputs. The outputs are:\n",
    "\n",
    "* **Evaluated model**: The evaluated model's `ModelArtifact`  is updated with all the following artifacts registered\n",
    "  to it.\n",
    "* **Test dataset**: The test set used to test the model post-training is logged as a `DatasetArtifact`.\n",
    "* **Plots**: Informative plots regarding the model like confusion matrix and features importance are drawn and logged\n",
    "  as `PlotArtifact`s.\n",
    "* **Results**: List of all the calculations of metrics tested on the testing set."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9503b38d",
   "metadata": {},
   "source": [
    "### Parameters\n",
    "\n",
    "To view the parameters of `evaluate`, expand the section below:\n",
    "\n",
    "````{dropdown} evaluate handler parameters:\n",
    "\n",
    "**Model Parameters**\n",
    "\n",
    "*Parameters to load a logged model.*\n",
    "\n",
    "* `model_path`: `str` &mdash; A `ModelArtifact` URI to load.\n",
    "\n",
    "**Data parameters**\n",
    "\n",
    "*Parameters to get a dataset and prepare it for training, splitting into training and testing if required.*\n",
    "\n",
    "* `dataset`: `Union[str, list, dict]` &mdash; The dataset to train the model on.\n",
    "  * Can be passed as part of `inputs` to be parsed as `mlrun.DataItem`, meaning it supports either a URI or a\n",
    "    FeatureVector.\n",
    "  * Can be passed as part of `params`, meaning it can be a `list` or a `dict`.\n",
    "* `drop_columns`: `Union[str, int, List[str], List[int]]` &mdash; columns to drop from the dataset. Can be passed as strings\n",
    "  representing the column names or integers representing the column numbers.\n",
    "* `label_columns`: `Union[str, int, List[str], List[int]]` &mdash; The target label(s) of the column(s) in the dataset. Can\n",
    "  be passed as strings representing the column names or integers representing the column numbers.\n",
    "\n",
    "**Predict parameters**\n",
    "\n",
    "*Parameters to pass to the `predict` method of the model object.*\n",
    "\n",
    "* `predict_kwargs`: `dict` &mdash; Additional parameters to pass onto the `predict` method.\n",
    "\n",
    "**Logging parameters**\n",
    "\n",
    "*Parameters to control the automatic logging feature of MLRun. You can adjust the logging outputs as relevant, and if\n",
    "not passed, a default list of artifacts and metrics is produced and calculated.*\n",
    "\n",
    "* `_artifacts`: `Dict[str, Union[list, dict]]` &mdash; Additional artifacts to produce post training. See the\n",
    "  `ArtifactsLibrary` of the desired framework to see the available list of artifacts.\n",
    "* `_metrics`: `Union[List[str], Dict[str, Union[list, dict]]]` &mdash; Additional metrics to calculate post training. See how\n",
    "  to pass metrics and custom metrics in the `MetricsLibrary` of the desired framework.\n",
    "* `apply_mlrun_kwargs`: `dict` &mdash; Framework specific `apply_mlrun` key word arguments. Refer to the framework of choice\n",
    "  to know more ([SciKit-Learn](), [XGBoost]() or [LightGBM]()).\n",
    "\n",
    "**Custom objects parameters**\n",
    "\n",
    "*Parameters to include custom objects for the evaluation like custom metric code and artifact plans. Keep in mind that\n",
    "the custom objects used to train the model are not required for loading it, it happens automatically.*\n",
    "\n",
    "* `custom_objects_map`: `Union[str, Dict[str, Union[str, List[str]]]]` &mdash; A map of all the custom objects required for\n",
    "  loading, training and testing the model. Can be passed as a dictionary or a json file path. Each key is a path to a\n",
    "  python file and its value is the custom object name to import from it. If multiple objects needed to be imported from\n",
    "  the same py file a list can be given. For example:\n",
    "  ```python\n",
    "  {\n",
    "      \"/.../custom_metric.py\": \"MyMetric\",\n",
    "      \"/.../custom_plans.py\": [\"plan1\", \"plan2\"]\n",
    "  }\n",
    "  ```\n",
    "  All the paths are accessed from the given 'custom_objects_directory', meaning each py file is read from the \n",
    "  'custom_objects_directory/<MAP VALUE>'. If the model path given is of a store object, the custom objects map is\n",
    "  read from the logged custom object map artifact of the model.\n",
    "  \n",
    "   ```{admonition} Note\n",
    "   The custom objects are imported in the order they came in this dictionary (or json). If a \n",
    "   custom object is depended on another, make sure to put it below the one it relies on.\n",
    "   ```\n",
    "   \n",
    "* `custom_objects_directory` &mdash; Path to the directory with all the python files required for the custom objects. Can be\n",
    "  passed as a zip file as well (iti is extracted during the start of the run).\n",
    "\n",
    "   ```{admonition} Note\n",
    "   The parameters for additional arguments `predict_kwargs` and `apply_mlrun_kwargs` can be also \n",
    "   passed in the global `kwargs` with the matching prefixes: `\"PREDICT_\"`, `\"MLRUN_\"`.\n",
    "   ```\n",
    "   \n",
    "````"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
